Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: DTVB with Swizzling (tensorB) #1562

Merged
merged 4 commits into from
Feb 7, 2025
Merged

Conversation

solaslin
Copy link
Contributor

@solaslin solaslin commented Jan 16, 2025

Resolved SWDEV-509997

  • Code-gen of Swizzled (SwizzleTensorB): Requires TN/NN and DTVB
  • FP16, BF16 MFMA_16x16x16_x1. Supported useBias and SAV
  • Supported other data-type which the MI_M/N is also 16 (Support other types for Swizzling #1574)
  • Included supporting for edge tile and tail-loop
  • Implemented for SwizzledB when WaveGroups[0] > 1 (M-Dim)
  • Supported arbitrary M & K for SwizzledB (padding)
  • pytests for SwizzledB

[gw2] [ 25%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/swizzleB.yaml]
[gw3] [ 50%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/swizzleA.yaml]
[gw0] [ 75%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtl.yaml]
[gw1] [100%] PASSED Tensile/Tests/common/test_config.py::test_config[Tensile/Tests/self_test/dtv.yaml]
=== 4 passed, 5 warnings in 2335.54s (0:38:55) ===
py310: OK (2487.54=setup[10.66]+cmd[0.39,140.81,2335.68] seconds)
congratulations :) (2487.61 seconds)

@solaslin solaslin added the noCI Disable testing on supported CI systems: math libraries CI has this feature enabled.. label Jan 16, 2025
@solaslin solaslin self-assigned this Jan 16, 2025
@solaslin solaslin force-pushed the swizzledB branch 4 times, most recently from c2565ef to 84c9d56 Compare January 23, 2025 02:41
@solaslin solaslin added enhancement New feature or request and removed noCI Disable testing on supported CI systems: math libraries CI has this feature enabled.. labels Jan 23, 2025
@solaslin solaslin marked this pull request as ready for review January 23, 2025 03:06
@solaslin solaslin added the gfx94x Run CI on gfx94x label Feb 3, 2025
Comment on lines 216 to 223
depthUDiv = kernel["DepthU"]
# Swizzled for A, TODO- check for SwizzleTensorB
depthUDiv = "%s%s"%(kernel["DepthU"], "*MI_M") if (tP["isSwizzled"] and tc == 'A') else "%s"%kernel["DepthU"]
#
# swizzle
if (tP["isSwizzled"] and tc == 'A'):
depthUDiv = "%s%s"%(kernel["DepthU"], "*MI_M")
elif (tP["isSwizzled"] and tc == 'B'):
depthUDiv = "%s%s"%(kernel["DepthU"], "*MI_N")

gsuOffsetStr = "gsuOffset = DepthU*bpeGR*GSUSumIdx"
Copy link
Contributor Author

@solaslin solaslin Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, all changes about renaming from MI_M to (MI_MN or MI_MorN) is for both A/B

Comment on lines 3168 to 3173

if tP["isA"] and tP["isSwizzled"]:
module.addModuleAsFlatItems(self.alignTo("StrideA0I", "StrideA0I", tP["swizzleK"]))
if tP["isSwizzled"]:
# "StrideA0I" or "StrideB1J"
strideName = "Stride%s%s"%(tc,self.states.indexChars[tP["idx"]])
module.addModuleAsFlatItems(self.alignTo(strideName, strideName, tP["swizzleK"]))

Copy link
Contributor Author

@solaslin solaslin Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strideName = "Stride%s%s"%(tc,self.states.indexChars[tP["idx"]])
this would be "StrideA0I" or "StrideB1J".
Thank Jimmy for providing this info.

Comment on lines 3414 to 3425
module.add(SAddU32(sgpr(tmpSgpr), sgpr("SizesSum"), swizzleStrideVal-1))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(swizzleStrideVal)), src=sgpr(tmpSgpr), comment="SWZ: numKr = DimK / %s"%swizzleStrideVal))
module.add(SLShiftRightB32(dst=sgpr(tmpSgpr), shiftHex=hex(log2(swizzleStrideVal)), src=sgpr(tmpSgpr), comment="SWZ-%s: numKr = DimK / %s"%(tc, swizzleStrideVal)))
WvG_M = kernel["MIWaveGroup"][0]
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) %= MIWG[0]"))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ: wave_id (along_M) *= numKr"))
if tP["isA"]:
module.add(VAndB32(dst=vgpr(qReg), src0=hex(WvG_M-1), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) mod MIWG[0]"%tc))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_M) *= numKr"%tc))
elif tP["isB"]:
# NB:
# Calc of w_id is: /= MIWG[0], not %= MIWG[1]
module.add(VLShiftRightB32(dst=vgpr(qReg), shiftHex=log2(WvG_M), src=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) /= MIWG[0]"%tc))
module.add(VMulU32U24(dst=vgpr(qReg), src0=sgpr(tmpSgpr), src1=vgpr(qReg), comment="SWZ-%s: wave_id (along_N) *= numKr"%tc))
elif isDTVAB:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be the most important note:
For swizzling A: the order of wave is wave_id = wave_id % MI_WaveG[0]
For swizzling B: the order of wave is wave_id =wave_id / MI_WaveG[0] (not wave_id % MI_WaveG[1])

Comment on lines 671 to 680
{
const auto k = desc.sizes()[0];
const auto m = desc.sizes()[1];
const auto b = desc.sizes()[2];
const auto swizzleK = miK * packK;
const auto paddedM = (m + miM - 1) / miM * miM;
const auto paddedK = (k + swizzleK - 1) / swizzleK * swizzleK;
return paddedM * paddedK * b;
// TODO: currently [0][1] = k, (m or n) is based on TN, need to make this generic in the future
const auto k = desc.sizes()[0];
const auto m_n = desc.sizes()[1];
const auto b = desc.sizes()[2];
const auto swizzleK = miK * packK;
const auto paddedM_N = (m_n + miM_N - 1) / miM_N * miM_N;
const auto paddedK = (k + swizzleK - 1) / swizzleK * swizzleK;
return paddedM_N * paddedK * b;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the changes in host code is to make clear that "M" should be "M or N"

Comment on lines +21 to +28

BenchmarkProblems:
########################################
# HHS TN DTVB + SWIZZLED_B + BIAS + Activation + SAV
########################################
-
- # ProblemType
OperationType: GEMM
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added similar tests for SwizzleB from SwizzleA

Copy link
Contributor Author

@solaslin solaslin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added self reviews

@geotseng-amd geotseng-amd self-requested a review February 7, 2025 05:51
@solaslin
Copy link
Contributor Author

solaslin commented Feb 7, 2025

The Failures in CI are the same as other PSs, which are "Solution not found" issues and have nothing to do with this PR. Merging.

@solaslin solaslin merged commit b390458 into ROCm:develop Feb 7, 2025
9 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request gfx94x Run CI on gfx94x
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants